/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import Unsafe from './Unsafe';
import Constants from './Constants';
import { FiniteStateEntropy, Table } from './FiniteStateEntropy'
import FseTableReader from './FseTableReader'
import NumberTransform from './NumberTransform'
import { BitInputStream, Initializer, Loader } from './BitInputStream'
import Long from '../util/long/index'
import Arrays from '../util/Arrays'
import Util from './Util'

/**
 * Huffman table reader and decoder.
 *
 * Build the decoding table with {@link Huffman.readTable}, then decode with
 * {@link Huffman.decodeSingleStream} (one bit stream) or
 * {@link Huffman.decode4Streams} (four bit streams preceded by a 6-byte jump
 * table). All addresses are `Long` offsets into the `inputBase` / `outputBase`
 * buffers, which are accessed through `Unsafe`.
 *
 * Truncated or corrupted input is rejected through `Util.verify`.
 */
export default class Huffman {
    public static MAX_SYMBOL: number = 255;
    public static MAX_SYMBOL_COUNT: number = Huffman.MAX_SYMBOL + 1;
    public static MAX_TABLE_LOG: number = 12;
    public static MIN_TABLE_LOG: number = 5;
    public static MAX_FSE_TABLE_LOG: number = 6;

    // stats
    // Per-symbol weights; index `outputSize` receives the implied last weight.
    private readonly weights: Int8Array = new Int8Array(Huffman.MAX_SYMBOL + 1);
    // Symbol count per weight, later rewritten into the next free table slot per weight.
    private readonly ranks: Int32Array = new Int32Array(Huffman.MAX_TABLE_LOG + 1);

    // table
    // log2 of the decoding table size; stays -1 until readTable sets it.
    private tableLog: number = -1;
    // Decoding table indexed by the next `tableLog` bits of a stream:
    // the decoded symbol and the number of bits it actually consumes.
    private readonly symbols: Int8Array = new Int8Array(1 << Huffman.MAX_TABLE_LOG);
    private readonly numbersOfBits: Int8Array = new Int8Array(1 << Huffman.MAX_TABLE_LOG);
    // Scratch state used when the weight list is FSE-compressed.
    private readonly reader: FseTableReader = new FseTableReader();
    private readonly fseTable: Table = Table.getTables(Huffman.MAX_FSE_TABLE_LOG);

    /** @returns true once {@link readTable} has set the table log (it is -1 before that). */
    public isLoaded(): boolean {
        return this.tableLog !== -1;
    }

    /**
     * Reads a serialized Huffman table description and builds the decoding table.
     *
     * The first byte is a header. When it is >= 128, `header - 127` weights
     * follow uncompressed as 4-bit nibbles, two per byte (high nibble first).
     * Otherwise the header is the byte length of an FSE-compressed weight list.
     * The weight of the last symbol is not stored: it is derived so that the
     * weights add up to a power of two.
     *
     * @param inputBase buffer holding the description
     * @param inputAddress offset of the header byte
     * @param size number of bytes available starting at `inputAddress`
     * @returns number of bytes consumed (header byte + payload)
     */
    public readTable(inputBase: any, inputAddress: Long, size: number): number {
        Arrays.fill(this.ranks, 0);
        let input: Long = inputAddress;

        // read table header
        Util.verify(size > 0, input, "Not enough input bytes");
        let inputSize: number = Unsafe.getByte(inputBase, input.toNumber()) & 0xFF;
        input = input.add(1);

        let outputSize: number;
        if (inputSize >= 128) {
            outputSize = inputSize - 127;
            // Integer division: ceil(outputSize / 2) payload bytes. A plain `/`
            // yields a fraction for even outputSize, which corrupts both the
            // size check below and the returned byte count.
            inputSize = (outputSize + 1) >> 1;

            Util.verify(inputSize + 1 <= size, input, "Not enough input bytes");
            Util.verify(outputSize <= Huffman.MAX_SYMBOL + 1, input, "Input is corrupted");

            for (let i: number = 0; i < outputSize; i += 2) {
                // Weights i and i + 1 share payload byte i / 2.
                const value: number = Unsafe.getByte(inputBase, input.add(i >> 1).toNumber()) & 0xFF;
                this.weights[i] = NumberTransform.toByte(value >>> 4);
                this.weights[i + 1] = NumberTransform.toByte(value & 0b1111);
            }
        }
        else {
            Util.verify(inputSize + 1 <= size, input, "Not enough input bytes");

            const inputLimit: Long = input.add(inputSize);
            input = input.add(this.reader.readFseTable(this.fseTable, inputBase, input, inputLimit, FiniteStateEntropy.MAX_SYMBOL, Huffman.MAX_FSE_TABLE_LOG));
            outputSize = FiniteStateEntropy.decompress(this.fseTable, inputBase, input, inputLimit, this.weights);
        }
        // weights[outputSize] must exist to hold the implied last weight; typed
        // arrays silently ignore out-of-range writes instead of failing.
        Util.verify(outputSize < this.weights.length, input, "Input is corrupted");

        let totalWeight: number = 0;
        for (let i: number = 0; i < outputSize; i++) {
            this.ranks[this.weights[i]]++;
            // (1 << w) >> 1 is 2^(w - 1) for w >= 1 and 0 for w == 0. It is NOT
            // the same as 1 << (w - 1): in JS, 1 << -1 === -2147483648.
            totalWeight += (1 << this.weights[i]) >> 1;
        }
        Util.verify(totalWeight !== 0, input, "Input is corrupted");

        this.tableLog = Util.highestBit(totalWeight) + 1;
        Util.verify(this.tableLog <= Huffman.MAX_TABLE_LOG, input, "Input is corrupted");

        // The implied last weight fills the gap up to the next power of two.
        const total: number = 1 << this.tableLog;
        const rest: number = total - totalWeight;
        Util.verify(Util.isPowerOf2(rest), input, "Input is corrupted");

        const lastWeight: number = Util.highestBit(rest) + 1;

        this.weights[outputSize] = lastWeight;
        this.ranks[lastWeight]++;

        const numberOfSymbols: number = outputSize + 1;

        // populate table
        // Turn per-weight counts into start offsets: each symbol of weight i
        // takes 2^(i - 1) slots, and weights are laid out in increasing order.
        let nextRankStart: number = 0;
        for (let i: number = 1; i < this.tableLog + 1; ++i) {
            const current: number = nextRankStart;
            nextRankStart += this.ranks[i] << (i - 1);
            this.ranks[i] = current;
        }

        for (let n: number = 0; n < numberOfSymbols; n++) {
            const weight: number = this.weights[n];
            const length: number = (1 << weight) >> 1; // 2^(weight - 1) slots; 0 for weight 0

            // Storing into Int8Array wraps symbols 128..255 to negative bytes.
            const symbol: number = n;
            const numberOfBits: number = this.tableLog + 1 - weight;
            const start: number = this.ranks[weight];
            for (let i: number = start; i < start + length; i++) {
                this.symbols[i] = symbol;
                this.numbersOfBits[i] = numberOfBits;
            }
            this.ranks[weight] = start + length;
        }

        // ranks[1] started at 0 and grew by one per weight-1 symbol, so this
        // requires an even number (at least two) of weight-1 symbols.
        Util.verify(this.ranks[1] >= 2 && (this.ranks[1] & 1) === 0, input, "Input is corrupted");

        return inputSize + 1;
    }

    /**
     * Decodes one Huffman bit stream spanning [inputAddress, inputLimit) into
     * [outputAddress, outputLimit) using the table built by {@link readTable}.
     */
    public decodeSingleStream(inputBase: any, inputAddress: Long, inputLimit: Long, outputBase: any, outputAddress: Long, outputLimit: Long): void {
        const initializer: Initializer = new Initializer(inputBase, inputAddress, inputLimit);
        initializer.initialize();

        let bits: Long = initializer.getBits();
        let bitsConsumed: number = initializer.getBitsConsumed();
        let currentAddress: Long = initializer.getCurrentAddress();

        const tableLog: number = this.tableLog;
        const numbersOfBits: Int8Array = this.numbersOfBits;
        const symbols: Int8Array = this.symbols;

        // Fast path: one bit-container reload per 4 decoded symbols.
        let output: Long = outputAddress;
        const fastOutputLimit: Long = outputLimit.sub(4);
        while (output.lessThan(fastOutputLimit)) {
            const loader: Loader = new Loader(inputBase, inputAddress, currentAddress, bits, bitsConsumed);
            const done: boolean = loader.load();
            bits = loader.getBits();
            bitsConsumed = loader.getBitsConsumed();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            bitsConsumed = Huffman.decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = Huffman.decodeSymbol(outputBase, output.add(1), bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = Huffman.decodeSymbol(outputBase, output.add(2), bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            bitsConsumed = Huffman.decodeSymbol(outputBase, output.add(3), bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            output = output.add(Constants.SIZE_OF_INT);
        }

        // Remaining symbols are decoded one at a time.
        this.decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits, outputBase, output, outputLimit);
    }

    /**
     * Decodes four interleaved Huffman bit streams.
     *
     * The input starts with three 16-bit sizes of streams 1-3 (read via
     * `Unsafe.getShort`); stream 4 runs to `inputLimit`. The output range is
     * split into four segments of ceil(length / 4) bytes. Streams 1-3 fill one
     * segment each, and stream 4 fills whatever is left.
     */
    public decode4Streams(inputBase: any, inputAddress: Long, inputLimit: Long, outputBase: any, outputAddress: Long, outputLimit: Long): void
    {
        Util.verify(
        inputLimit.sub(inputAddress)
            .greaterThanOrEqual(10), inputAddress, "Input is corrupted"); // jump table + 1 byte per stream

        const start1: Long = inputAddress.add(3 * Constants.SIZE_OF_SHORT); // for the shorts we read below
        const start2: Long = start1.add(Unsafe.getShort(inputBase, inputAddress.toNumber()) & 0xFFFF);
        const start3: Long = start2.add(Unsafe.getShort(inputBase, inputAddress.add(2).toNumber()) & 0xFFFF);
        const start4: Long = start3.add(Unsafe.getShort(inputBase, inputAddress.add(4).toNumber()) & 0xFFFF);

        let initializer: Initializer = new Initializer(inputBase, start1, start2);
        initializer.initialize();
        let stream1bitsConsumed: number = initializer.getBitsConsumed();
        let stream1currentAddress: Long = initializer.getCurrentAddress();
        let stream1bits: Long = initializer.getBits();

        initializer = new Initializer(inputBase, start2, start3);
        initializer.initialize();
        let stream2bitsConsumed: number = initializer.getBitsConsumed();
        let stream2currentAddress: Long = initializer.getCurrentAddress();
        let stream2bits: Long = initializer.getBits();

        initializer = new Initializer(inputBase, start3, start4);
        initializer.initialize();
        let stream3bitsConsumed: number = initializer.getBitsConsumed();
        let stream3currentAddress: Long = initializer.getCurrentAddress();
        let stream3bits: Long = initializer.getBits();

        initializer = new Initializer(inputBase, start4, inputLimit);
        initializer.initialize();
        let stream4bitsConsumed: number = initializer.getBitsConsumed();
        let stream4currentAddress: Long = initializer.getCurrentAddress();
        let stream4bits: Long = initializer.getBits();

        const segmentSize: number = (outputLimit.sub(outputAddress).add(3)).divide(4).toInt();

        const outputStart2: Long = outputAddress.add(segmentSize);
        const outputStart3: Long = outputStart2.add(segmentSize);
        const outputStart4: Long = outputStart3.add(segmentSize);

        let output1: Long = outputAddress;
        let output2: Long = outputStart2;
        let output3: Long = outputStart3;
        let output4: Long = outputStart4;

        const fastOutputLimit: Long = outputLimit.sub(7);
        const tableLog: number = this.tableLog;
        const numbersOfBits: Int8Array = this.numbersOfBits;
        const symbols: Int8Array = this.symbols;

        // Fast path: 4 symbols per stream, then reload every stream's bit container.
        while (output4.lessThan(fastOutputLimit)) {
            stream1bitsConsumed = Huffman.decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = Huffman.decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = Huffman.decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = Huffman.decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = Huffman.decodeSymbol(outputBase, output1.add(1), stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = Huffman.decodeSymbol(outputBase, output2.add(1), stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = Huffman.decodeSymbol(outputBase, output3.add(1), stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = Huffman.decodeSymbol(outputBase, output4.add(1), stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = Huffman.decodeSymbol(outputBase, output1.add(2), stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = Huffman.decodeSymbol(outputBase, output2.add(2), stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = Huffman.decodeSymbol(outputBase, output3.add(2), stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = Huffman.decodeSymbol(outputBase, output4.add(2), stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            stream1bitsConsumed = Huffman.decodeSymbol(outputBase, output1.add(3), stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
            stream2bitsConsumed = Huffman.decodeSymbol(outputBase, output2.add(3), stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
            stream3bitsConsumed = Huffman.decodeSymbol(outputBase, output3.add(3), stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
            stream4bitsConsumed = Huffman.decodeSymbol(outputBase, output4.add(3), stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);

            output1 = output1.add(Constants.SIZE_OF_INT);
            output2 = output2.add(Constants.SIZE_OF_INT);
            output3 = output3.add(Constants.SIZE_OF_INT);
            output4 = output4.add(Constants.SIZE_OF_INT);

            let loader: Loader = new Loader(inputBase, start1, stream1currentAddress, stream1bits, stream1bitsConsumed);
            let done: boolean = loader.load();
            stream1bitsConsumed = loader.getBitsConsumed();
            stream1bits = loader.getBits();
            stream1currentAddress = loader.getCurrentAddress();

            if (done) {
                break;
            }

            loader = new Loader(inputBase, start2, stream2currentAddress, stream2bits, stream2bitsConsumed);
            done = loader.load();
            stream2bitsConsumed = loader.getBitsConsumed();
            stream2bits = loader.getBits();
            stream2currentAddress = loader.getCurrentAddress();

            if (done) {
                break;
            }

            loader = new Loader(inputBase, start3, stream3currentAddress, stream3bits, stream3bitsConsumed);
            done = loader.load();
            stream3bitsConsumed = loader.getBitsConsumed();
            stream3bits = loader.getBits();
            stream3currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            loader = new Loader(inputBase, start4, stream4currentAddress, stream4bits, stream4bitsConsumed);
            done = loader.load();
            stream4bitsConsumed = loader.getBitsConsumed();
            stream4bits = loader.getBits();
            stream4currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }
        }

        // Compare with Long methods: `<=` on two Long objects coerces them to
        // primitives (strings, for a Long without a numeric valueOf) and
        // compares lexicographically, which is wrong for offsets.
        Util.verify(
            outputStart2.greaterThanOrEqual(output1)
                && outputStart3.greaterThanOrEqual(output2)
                && outputStart4.greaterThanOrEqual(output3),
            inputAddress, "Input is corrupted");

        this.decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed, stream1bits, outputBase, output1, outputStart2);
        this.decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed, stream2bits, outputBase, output2, outputStart3);
        this.decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed, stream3bits, outputBase, output3, outputStart4);
        this.decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed, stream4bits, outputBase, output4, outputLimit);
    }

    /**
     * Decodes single symbols until `outputLimit`, reloading the bit container
     * before each one until the loader reports done. It then decodes any
     * remaining symbols without reloading, and finally verifies that the
     * stream starting at `startAddress` was consumed exactly.
     */
    private decodeTail(inputBase: any, startAddress: Long, currentAddress: Long, bitsConsumed: number, bits: Long, outputBase: any, outputAddress: Long, outputLimit: Long): void
    {
        const tableLog: number = this.tableLog;
        const numbersOfBits: Int8Array = this.numbersOfBits;
        const symbols: Int8Array = this.symbols;

        // Compare with Long methods, not `<`: relational operators on Long
        // objects fall back to primitive coercion (strings for a Long without a
        // numeric valueOf), so e.g. offsets 9 vs 10 would compare as "9" < "10" === false.
        while (outputAddress.lessThan(outputLimit)) {
            const loader: Loader = new Loader(inputBase, startAddress, currentAddress, bits, bitsConsumed);
            const done: boolean = loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }

            bitsConsumed = Huffman.decodeSymbol(outputBase, outputAddress, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            outputAddress = outputAddress.add(1);
        }
        // The loader reported done: decode whatever is left without reloading.
        while (outputAddress.lessThan(outputLimit)) {
            bitsConsumed = Huffman.decodeSymbol(outputBase, outputAddress, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
            outputAddress = outputAddress.add(1);
        }

        Util.verify(BitInputStream.isEndOfStream(startAddress, currentAddress, bitsConsumed), startAddress, "Bit stream is not fully consumed");
    }

    /**
     * Peeks the next `tableLog` bits, writes the matching symbol to
     * `outputAddress`, and returns `bitsConsumed` advanced by that symbol's
     * actual code length.
     */
    private static decodeSymbol(outputBase: any, outputAddress: Long, bitContainer: Long, bitsConsumed: number, tableLog: number, numbersOfBits: Int8Array, symbols: Int8Array): number
    {
        const value: number = BitInputStream.peekBitsFast(bitsConsumed, bitContainer, tableLog).toInt();
        Unsafe.putByte(outputBase, outputAddress.toNumber(), symbols[value]);
        return bitsConsumed + numbersOfBits[value];
    }
}
